import arcpy, os
from arcpy.sa import *


class Toolbox(object):
    """Image Processing toolbox.

    The toolbox itself is named after the .pyt file; this class only
    supplies its label, its (empty) alias and the tool classes it exposes.
    """

    def __init__(self):
        # Display label for the toolbox.
        self.label = "Image Processing"
        # No alias is configured for this toolbox.
        self.alias = ""
        # Tool classes associated with this toolbox.
        tool_classes = [TasseledCap]
        self.tools = tool_classes


class TasseledCap(object):
    """Tasseled Cap transformation of a 6-band Landsat image.

    Each output band (brightness, greenness, wetness) is a linear
    combination of the six input layers using the coefficient tables below;
    the three bands are then composited into the output raster.
    """

    # Names of the band sub-datasets inside the input raster; each band is
    # addressed as <input raster path>/<layer name>.
    _LAYERS = ('Layer_1', 'Layer_2', 'Layer_3', 'Layer_4', 'Layer_5', 'Layer_6')

    # Order of coefficients is TM1, TM2, TM3, TM4, TM5, TM7, where 1-7 are the
    # Landsat TM band numbers (there is no TM6 coefficient, so the six input
    # layers presumably correspond to TM1-5 and TM7 -- confirm with inputs).
    _BRIGHTNESS = (0.2909, 0.2493, 0.4806, 0.5568, 0.4438, 0.1706)
    _GREENNESS = (-0.2728, -0.2174, -0.5508, 0.7221, 0.0733, -0.1648)
    # NOTE(review): published Landsat TM tasseled cap tables commonly give the
    # TM1 wetness coefficient as 0.1446, not 0.1466 -- verify against the
    # intended reference before relying on the wetness band.
    _WETNESS = (0.1466, 0.1761, 0.3322, 0.3396, -0.6210, -0.4186)

    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "Tasseled Cap"
        self.description = ""
        self.canRunInBackground = False

    def getParameterInfo(self):
        """Define parameter definitions.

        Returns:
            list: [input raster parameter, output raster parameter].
        """
        inRaster = arcpy.Parameter(
            displayName="Input Landsat Image",
            name="in_raster",
            datatype="DERasterDataset",
            parameterType="Required",
            direction="Input")

        outRaster = arcpy.Parameter(
            displayName="Output Tasseled Cap Image",
            name="out_raster",
            datatype="DERasterDataset",
            parameterType="Required",
            direction="Output")

        return [inRaster, outRaster]

    def isLicensed(self):
        """Report whether the Spatial Analyst extension is available.

        This only queries availability; it does not check the license out
        (execute() does that).

        Returns:
            bool: True if CheckExtension reports "Available", else False.
        """
        try:
            available = arcpy.CheckExtension("Spatial") == "Available"
        except Exception:
            # Treat a failed availability query the same as "not available".
            available = False

        if not available:
            arcpy.AddError("Spatial Analyst Extension not available")
        return available

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed.  This method is called whenever a parameter
        has been changed.  No parameter interaction is needed for this tool."""
        return

    def updateMessages(self, parameters):
        """Flag the input raster unless it has exactly 6 bands.

        Called after internal validation.  The band count is only checked
        when the input has a value and internal validation has not already
        flagged it, so Describe is never called on an empty or invalid input.
        """
        in_param = parameters[0]
        if (in_param.altered
                and in_param.value is not None
                and not in_param.hasError()):
            numbands = arcpy.Describe(in_param.valueAsText).bandCount
            if numbands != 6:
                in_param.setErrorMessage(
                    "Number of bands is not 6. Image must contain 6 bands")
        return

    @staticmethod
    def _weighted_sum(rasters, coefficients):
        """Return rasters[0]*coefficients[0] + ... + rasters[-1]*coefficients[-1].

        Terms are accumulated left to right.  Works with any operands that
        support * and + (Spatial Analyst Raster objects or plain numbers).

        Raises:
            ValueError: if the sequences are empty or differ in length.
        """
        if not rasters or len(rasters) != len(coefficients):
            raise ValueError(
                "rasters and coefficients must be non-empty and equal length")
        total = rasters[0] * coefficients[0]
        for layer, coef in zip(rasters[1:], coefficients[1:]):
            total = total + layer * coef
        return total

    def execute(self, parameters, messages):
        """Run the tasseled cap transformation.

        Writes the intermediate single-band rasters bright1.img, green1.img
        and wet1.img into the output raster's directory (overwritten on every
        run), then composites them, in that order, into the output raster.
        The Spatial Analyst license is checked in again even if a step fails.
        """
        # Overwrite intermediate and final outputs if the files exist.
        arcpy.env.overwriteOutput = True

        in_image = parameters[0].valueAsText     # the input image
        tasseled_cap = parameters[1].valueAsText  # the output image

        # Intermediate band files live next to the output image and are
        # always overwritten each time the tool executes.
        outpath = os.path.dirname(tasseled_cap)
        brightbnd = os.path.join(outpath, 'bright1.img')
        greenbnd = os.path.join(outpath, 'green1.img')
        wetbnd = os.path.join(outpath, 'wet1.img')

        arcpy.CheckOutExtension("spatial")
        try:
            arcpy.AddMessage("Processing the image...")
            arcpy.AddMessage("Running tasseled cap")

            # Open each input band once and reuse it for all three indices.
            layers = [Raster(os.path.join(in_image, name))
                      for name in self._LAYERS]

            outputs = ((self._BRIGHTNESS, brightbnd),
                       (self._GREENNESS, greenbnd),
                       (self._WETNESS, wetbnd))
            for coefficients, band_path in outputs:
                Float(self._weighted_sum(layers, coefficients)).save(band_path)

            # Build the image using Composite Bands (bright, green, wet).
            arcpy.CompositeBands_management(
                ';'.join([brightbnd, greenbnd, wetbnd]), tasseled_cap)

            arcpy.AddMessage("Completed Tasseled Cap.")
        finally:
            # Check the extension back in for others to use, even on failure.
            arcpy.CheckInExtension("spatial")
        return